{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torchvision import datasets, transforms\n",
    "from torch.autograd import Variable\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# MNIST images have a single channel, so Normalize takes exactly one mean/std value.\n",
    "# (x - 0.5) / 0.5 rescales the [0, 1] tensors produced by ToTensor to [-1, 1].\n",
    "transform = transforms.Compose([transforms.ToTensor(),\n",
    "                                transforms.Normalize(mean=[0.5], std=[0.5])])\n",
    "\n",
    "# Training split; downloaded into ./data on first use.\n",
    "dataset_train = datasets.MNIST(root = \"./data\",\n",
    "                               transform = transform,\n",
    "                               train = True,\n",
    "                               download = True)\n",
    "\n",
    "# Test split; read from the same ./data root (no download requested here).\n",
    "dataset_test = datasets.MNIST(root = \"./data\",\n",
    "                              transform = transform,\n",
    "                              train = False)\n",
    "\n",
    "# Mini-batch iterators over both splits (64 images per batch, reshuffled on each pass).\n",
    "train_load = torch.utils.data.DataLoader(dataset = dataset_train,\n",
    "                                         batch_size = 64,\n",
    "                                         shuffle = True)\n",
    "\n",
    "test_load = torch.utils.data.DataLoader(dataset = dataset_test,\n",
    "                                        batch_size = 64,\n",
    "                                        shuffle = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Pull a single mini-batch from the training loader to preview it.\n",
    "images, label = next(iter(train_load))\n",
    "print(images.shape)\n",
    "\n",
    "# Per-channel statistics used to undo the 0.5/0.5 normalization before plotting.\n",
    "mean = [0.5,0.5,0.5]\n",
    "std = [0.5,0.5,0.5]\n",
    "\n",
    "# Tile the batch into one image, then move the channel axis last for matplotlib.\n",
    "images_example = torchvision.utils.make_grid(images).numpy().transpose(1,2,0)\n",
    "images_example = images_example*std + mean\n",
    "\n",
    "plt.imshow(images_example)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Add zero-mean Gaussian noise scaled by 0.5 to the preview grid, then clip\n",
    "# the result back into the [0, 1] range before displaying it.\n",
    "noisy_images = np.clip(\n",
    "    images_example + 0.5*np.random.randn(*images_example.shape),\n",
    "    0.,\n",
    "    1.,\n",
    ")\n",
    "plt.imshow(noisy_images)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class AutoEncoder(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(AutoEncoder, self).__init__()\n",
    "        self.encoder = torch.nn.Sequential(torch.nn.Conv2d(1,64, kernel_size=3, stride=1, padding=1),\n",
    "                                           torch.nn.ReLU(),\n",
    "                                           torch.nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "                                           torch.nn.Conv2d(64,128, kernel_size=3, stride=1, padding=1),\n",
    "                                           torch.nn.ReLU(),\n",
    "                                           torch.nn.MaxPool2d(kernel_size=2, stride=2))\n",
    "        \n",
    "        self.decoder = torch.nn.Sequential(torch.nn.Upsample(scale_factor=2, mode=\"nearest\"),\n",
    "                                           torch.nn.Conv2d(128,64, kernel_size=3, stride=1, padding=1),\n",
    "                                           torch.nn.ReLU(),\n",
    "                                           torch.nn.Upsample(scale_factor=2, mode=\"nearest\"),\n",
    "                                           torch.nn.Conv2d(64,1, kernel_size=3, stride=1, padding=1))\n",
    "        \n",
    "    def forward(self, input):\n",
    "        output = self.encoder(input)\n",
    "        output = self.decoder(output)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Number of passes over the training data.\n",
    "epoch_n = 5\n",
    "\n",
    "# Reconstruction objective: mean-squared error.\n",
    "loss_f = torch.nn.MSELoss()\n",
    "\n",
    "# Build the network and place it on the GPU when CUDA is available.\n",
    "Use_gpu = torch.cuda.is_available()\n",
    "model = AutoEncoder().cuda() if Use_gpu else AutoEncoder()\n",
    "print(model)\n",
    "\n",
    "# Adam with its default settings, created after the model has been moved.\n",
    "optimizer = torch.optim.Adam(model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Train on whichever device the model's parameters live on, so the loop\n",
    "# runs both with and without CUDA instead of calling .cuda() unconditionally.\n",
    "device = next(model.parameters()).device\n",
    "\n",
    "for epoch in range(epoch_n):\n",
    "    running_loss = 0.0\n",
    "    print(\"Epoch {}/{}\".format(epoch, epoch_n))\n",
    "    print(\"-\"*10)\n",
    "\n",
    "    for data in train_load:\n",
    "        X_train, _ = data\n",
    "\n",
    "        # Corrupt the input with Gaussian noise scaled by 0.5; the clean batch is the target.\n",
    "        # NOTE(review): the noisy input is clamped to [0, 1] while the target keeps the\n",
    "        # range produced by `transform` -- confirm this mismatch is intended.\n",
    "        noisy_X_train = X_train + 0.5*torch.randn(X_train.shape)\n",
    "        noisy_X_train = torch.clamp(noisy_X_train, 0., 1.)\n",
    "        X_train, noisy_X_train = X_train.to(device), noisy_X_train.to(device)\n",
    "\n",
    "        train_pre = model(noisy_X_train)\n",
    "        loss = loss_f(train_pre, X_train)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Accumulate inside the batch loop. The loss is a batch mean, so weight it by\n",
    "        # the batch size to make the division by len(dataset_train) a per-sample average.\n",
    "        running_loss += loss.item() * X_train.size(0)\n",
    "\n",
    "    # Report once per epoch, before running_loss is reset for the next one.\n",
    "    print(\"Loss is:{:.4f}\".format(running_loss/len(dataset_train)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
